{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2020 The Google Research Authors.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Activation Clustering Model: Training\n",
    "\n",
    "This notebook shows how to train an activation clustering model from a trained baseline Keras model.  Here we use a ResNet classification model trained on the CIFAR-10 dataset as an example.  The model is included as `model.h5`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "from activation_clustering import ac_model, utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The same dataset preprocessing as used in the baseline cifar10 model training.\n",
    "def input_fn(batch_size, ds, label_key='label'):\n",
    "    dataset = ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "    def interface(batch):\n",
    "        # Scale uint8 pixel values to floats in [0, 1].\n",
    "        features = tf.cast(batch['image'], tf.float32) / 255\n",
    "        labels = batch[label_key]\n",
    "\n",
    "        return features, labels\n",
    "\n",
    "    return dataset.map(interface)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train an activation clustering model from a baseline model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the baseline model for the actication clustering model to calculate activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.models.load_model('model.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Activation clustering model's configurations.  The first entry in each pair is a layer name of the baseline model, whose output activations will be clustered.  The second entry is a dict with key `n_clusters` specifying the number of clusters.\n",
    "\n",
    "We use deep embedding clustering (DEC) as the clustering algorithm in this implementation, which has several other parameters that you can expose by modifying the `activation_clustering` library and configure here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clustering_config = [\n",
    "    ('activation', {'n_clusters': 15}),\n",
    "    ('activation_18', {'n_clusters': 15}),\n",
    "    ('activation_36', {'n_clusters': 15}),\n",
    "    ('activation_54', {'n_clusters': 15})\n",
    "]\n",
    "\n",
    "# Uncomment this for shorter training time for debugging/test runs.\n",
    "# clustering_config = [\n",
    "#     ('activation', {'n_clusters': 10}),\n",
    "#     ('activation_54', {'n_clusters': 10, 'filters': [16, 16, 16, 8]})\n",
    "# ]\n",
    "\n",
    "work_dir = 'new_work_dir'"
   ]
  },
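  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the DEC algorithm mentioned above, the next cell computes DEC-style soft cluster assignments (a Student's t kernel between embedded points and cluster centroids) and the sharpened target distribution in plain NumPy.  This is a minimal sketch on random data, not the `activation_clustering` implementation: the names `soft_assign`, `target_distribution`, and `alpha` are hypothetical and chosen only for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only; not part of the activation_clustering library.\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def soft_assign(z, centroids, alpha=1.0):\n",
    "    # Squared Euclidean distance between every point and every centroid: shape (n, k).\n",
    "    d2 = ((z[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)\n",
    "    # Student's t kernel, normalized so that each row sums to 1.\n",
    "    q = (1.0 + d2 / alpha) ** (-(alpha + 1.0) / 2.0)\n",
    "    return q / q.sum(axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "def target_distribution(q):\n",
    "    # Square the assignments and divide by the soft cluster frequencies,\n",
    "    # then renormalize each row.\n",
    "    w = q ** 2 / q.sum(axis=0)\n",
    "    return w / w.sum(axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "# Random stand-ins for embedded activations and cluster centroids (assumed shapes).\n",
    "rng = np.random.default_rng(0)\n",
    "z = rng.normal(size=(6, 4))\n",
    "centroids = rng.normal(size=(3, 4))\n",
    "\n",
    "q = soft_assign(z, centroids)\n",
    "p = target_distribution(q)\n",
    "print(q.shape, p.shape)\n",
    "print(q.argmax(axis=1))  # hard cluster label for each point"
   ]
  },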
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_acm = ac_model.ACModel(model, clustering_config, work_dir=work_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calling `build_clustering_models` creates clustering models, one for each specified activation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_acm.build_clustering_models()\n",
    "new_acm.clustering_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = tfds.load(\n",
    "    'cifar10:3.*.*',\n",
    "    shuffle_files=False,\n",
    "    split='train'\n",
    ")\n",
    "\n",
    "test_ds = tfds.load(\n",
    "    'cifar10:3.*.*',\n",
    "    shuffle_files=False,\n",
    "    split='test'\n",
    ")\n",
    "\n",
    "# Uncomment the lines below to use only 10% of the data for shorter training time.\n",
    "# train_ds = tfds.load(\n",
    "#     'cifar10:3.*.*',\n",
    "#     shuffle_files=False,\n",
    "#     split='train[:10%]'\n",
    "# )\n",
    "\n",
    "# test_ds = tfds.load(\n",
    "#     'cifar10:3.*.*',\n",
    "#     shuffle_files=False,\n",
    "#     split='test[:10%]'\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Cache the activations to make it easier to iterate.\n",
    "\n",
    "batch_size = 500\n",
    "\n",
    "ds = input_fn(batch_size, train_ds)\n",
    "new_acm.cache_activations(ds, tag='train')\n",
    "del ds\n",
    "\n",
    "ds = input_fn(batch_size, test_ds)\n",
    "new_acm.cache_activations(ds, tag='test')\n",
    "del ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "activations_dict = new_acm.load_activations_dict(\n",
    "    activations_filename=work_dir+'/activations/activations_train.npz')\n",
    "\n",
    "test_activations_dict = new_acm.load_activations_dict(\n",
    "    activations_filename=work_dir+'/activations/activations_test.npz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for k, v in activations_dict.items():\n",
    "    print(k, v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Here we use a small number of epochs/iterations for shorter training time.\n",
    "# The activation clustering training loop handles model saving in its `work_dir`.\n",
    "\n",
    "epochs = 15\n",
    "maxiter = 980\n",
    "\n",
    "# # Uncomment this for shorter training time\n",
    "# epochs = 2\n",
    "# maxiter = 280\n",
    "\n",
    "new_acm.fit(activations_dict=activations_dict, epochs=epochs, maxiter=maxiter)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
